#include <torch/torch.h>
#include <iostream>
#include <cmath>
#include <cstdio>
#include <vector>









/// Minimal hand-written affine layer built directly on registered parameters.
///
/// Holds a weight matrix `W` of shape {N, M} and a bias vector `b` of length M,
/// both initialised from `torch::randn` and registered with the module so they
/// are returned by `parameters()` / `named_parameters()`.
struct Net : torch::nn::Module {
    torch::Tensor W, b;

    /// @param N number of rows of `W` (input feature count).
    /// @param M number of columns of `W` and length of `b` (output feature count).
    Net(int64_t N, int64_t M) {
        // Register the parameters of the model under the names "W" and "b".
        W = register_parameter("W", torch::randn({N, M}));
        b = register_parameter("b", torch::randn(M));
    }

    /// Computes `b + input * W` (matrix product) via `torch::addmm`.
    ///
    /// @param input matrix operand multiplied by `W`; assumed to be a 2-D
    ///              tensor with N columns, as `addmm` works on matrices —
    ///              TODO confirm against callers.
    /// @return the result of `torch::addmm(b, input, W)`.
    torch::Tensor forward(torch::Tensor input) {
        return torch::addmm(b, input, W);
    }
};






struct LinearNet : torch::nn::Module
{
    LinearNet(int64_t N, int64_t M)
        // here we register a submodule torch::nn::Linear
        // and torch::nn::Sigmoid
        : linear(register_module("linear", torch::nn::Linear(N, M))),
          sigmoid(register_module("sigmoid", torch::nn::Sigmoid()))
    {
        // here we register parameter
        another_bias = register_parameter("b", torch::randn(M));
    }
    torch::nn::Linear linear;
    torch::Tensor another_bias;
    torch::nn::Sigmoid sigmoid;

    torch::Tensor forward(torch::Tensor input)
    {
        // here is the forward function
        return sigmoid(linear(input) + another_bias);
    }
};






// Total number of training epochs run by the loop in main().
constexpr int64_t kNumberOfEpochs{30};



/// Demo driver.
///
/// Seeds the RNG, prints the parameters of a `Net` and a `LinearNet`, then
/// runs `kNumberOfEpochs` Adam steps on the `LinearNet` using random inputs
/// and all-zero targets, reporting the binary cross-entropy loss every fifth
/// epoch.
int main() {
    torch::manual_seed(7);

    // Show the randomly initialised parameters of the hand-written layer.
    Net net(4, 5);
    for (const auto& p : net.parameters()) {
        std::cout << p << std::endl;
    }

    // Show each parameter of the composed module together with its name.
    LinearNet ln(7, 3);
    for (const auto& pair : ln.named_parameters()) {
        std::cout << pair.key() << ": " << pair.value() << std::endl;
    }

    torch::optim::Adam ln_optimizer(
        ln.parameters(),
        torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));

    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        // Clear gradients accumulated by the previous iteration.
        ln.zero_grad();
        torch::Tensor fake_data = torch::randn(7);
        // Optional input normalisation, currently disabled:
        // fake_data = fake_data.sub(fake_data.mean()).div(fake_data.std());
        torch::Tensor fake_labels = torch::zeros(3);
        torch::Tensor fake_output = ln.forward(fake_data);
        torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
        d_loss_fake.backward();
        ln_optimizer.step();
        if (epoch % 5 == 0) {
            // int64_t is not `long` on every platform, so cast explicitly to
            // `long long` and print with %lld to keep the format string and
            // the argument types in agreement everywhere.
            std::printf(
                "\r[%2lld/%2lld] D_loss: %.4f\n",
                static_cast<long long>(epoch),
                static_cast<long long>(kNumberOfEpochs),
                d_loss_fake.item<float>());
        }
    }
}